from distutils.util import strtobool
import sys
from functools import reduce
from math import ceil
from typing import Tuple, List, Optional
import _pickle as pkl
import pickle
import numpy as np
import os
from tqdm import tqdm
import argparse

from Utils import logger, get_logger
from Utils.Constants import FileNamesConstants as Names
from FeatureMap.FeatureMapCreator import FeatureMapGenerator


# Default set of statistic names used to build a single-layer feature map (CLI default for --stats_to_use).
PARTIAL_SINGLE_LAYER_WEIGHTS_STATS = (
    'mean',
    'variance',
    'median',
    'std',
    'max',
    'min',
    'covariance',
    'skewness',
    'kurtosis',
    'q-th_percentile',
    'L1_norm',
    'L2_norm',
)


def _get_mini_map_name(path_or_name):
    """
    Build the output file name of a minimized feature map from a stats file path or name.

    The directory part is dropped, everything from the first '.' onwards is cut off, every '_stats'
    substring is removed, and '_<Names.MINIMIZED_MAP>.pkl' is appended.
    NOTE(review): assumes Names.MINIMIZED_MAP is a str.
    :param path_or_name: full path or bare file name of a stats file
    :return: file name (no directory) of the minimized map
    """
    stem = os.path.basename(path_or_name).partition('.')[0]
    stem = stem.replace('_stats', '')
    return '_'.join([stem, Names.MINIMIZED_MAP]) + '.pkl'


def create_and_save_single_map(file_path: str, out_folder: str, parts: int, nan_th: float,
                               is_weights: bool, is_single_layer: bool, stats_to_use: Tuple[str]):
    """
    Create and save a single feature map into a fast pickle format.

    The map is NOT created when the output file already exists, when the number of nan values exceeds
    ``nan_th * data_size``, or when the number of inf values exceeds ``nan_th * data_size``.
    The data is first written to a temporary file and then atomically renamed to the final name, so an
    interrupted run never leaves a truncated map that later runs would mistake for a finished one.
    :param file_path: full path to stats file
    :param out_folder: where to save the map
    :param parts: Number of parts to split the feature map while saving to pickle to avoid memory problems (Fewer parts
                  means faster loading, but can cause out of memory error). Each part is pickled separately, one
                  after the other, into the same file.
    :param nan_th: maximum allowed fraction (0-1) of nan values, and separately of inf values, relative to the total
                   number of elements in the generated data.
    :param is_weights: forwarded to FeatureMapGenerator.create
    :param is_single_layer: forwarded to FeatureMapGenerator.create
    :param stats_to_use: forwarded to FeatureMapGenerator.create
    :return: None
    """
    file_name = _get_mini_map_name(file_path)
    out_file_path = os.path.join(out_folder, file_name)
    if os.path.exists(out_file_path):
        logger().force_log_and_print('FeatureMapDataCreator::create_and_save_single_map',
                                     f'File: {out_file_path} already exists - skipping')
        return

    logger().force_log_and_print('FeatureMapDataCreator::create_and_save_single_map',
                                 f'Creating from file: {file_path}  to {out_file_path}')
    gen = FeatureMapGenerator.create(is_weights=is_weights, is_single_layer=is_single_layer,
                                     stats_to_use=stats_to_use, file_name_or_path=file_path)
    all_data = np.array(list(gen))
    nan_count = np.sum(np.isnan(all_data))
    inf_count = np.sum(np.isinf(all_data))
    data_size = all_data.size
    max_invalid = data_size * nan_th

    # Evaluate both conditions independently so each problem is reported, but skip saving if EITHER fails
    # (previously a nan failure alone still fell through to saving the map).
    too_many_nan = nan_count > max_invalid
    too_many_inf = inf_count > max_invalid
    if too_many_nan:
        logger().force_log_and_print('FeatureMapDataCreator::create_and_save_single_map',
                                     f'File: {file_path} has too many nan values: {nan_count}, '
                                     f'for data with size: {data_size} -- skipping creation of this file')
    if too_many_inf:
        logger().force_log_and_print('FeatureMapDataCreator::create_and_save_single_map',
                                     f'File: {file_path} has to many inf values: {inf_count}, '
                                     f'for data with size: {data_size} -- skipping creation of this file')
    if too_many_nan or too_many_inf:
        return

    part_size = ceil(all_data.shape[0] / parts)
    # Write to a temp file first: the "already exists" check above must only ever see complete maps.
    tmp_file_path = out_file_path + '.tmp'
    try:
        with open(tmp_file_path, 'wb') as file:
            for i in range(parts):
                pkl.dump(all_data[i * part_size: (i + 1) * part_size], file, protocol=pickle.HIGHEST_PROTOCOL)
        os.replace(tmp_file_path, out_file_path)
    finally:
        # Only left behind if writing or renaming failed; never leave partial data around.
        if os.path.exists(tmp_file_path):
            os.remove(tmp_file_path)

    logger().force_log_and_print('FeatureMapDataCreator::create_and_save_single_map',
                                 f'Done creating: {out_file_path}')


def _get_chunks(total_workers: int, files: List):
    step_size = ceil(len(files)/total_workers)
    if step_size == 0:
        step_size = 1
    chunks = list()
    for i in range(0, len(files), step_size):
        chunks.append(files[i: i+step_size])
    return chunks


def _find_stats_files(inputs_folder: str, is_weights: bool) -> List[str]:
    """Return full paths, sorted by file name, of the weights (or gradients) stats files directly in inputs_folder."""
    marker = Names.WEIGHTS_STATS if is_weights else Names.GRADIENTS_STATS
    names = sorted(name for name in os.listdir(inputs_folder) if marker in name)
    return [os.path.join(inputs_folder, name) for name in names]


def create_maps_from_folders(total_workers: int, out_folder: str, inputs_folder: str, parts: int, nan_th: float,
                             is_weights: bool, is_single_layer: bool, stats_to_use: Tuple[str],
                             map_files_to_create: Optional[List[str]] = None):
    """
    Create partial feature maps and save them from all relevant stats files in the inputs_folder.
    This function supports slurm job array to split workload: when the SLURM_ARRAY_TASK_ID environment variable is
    set and total_workers != -1, only the chunk of files assigned to this task id is processed.
    :param total_workers: slurm number of jobs (-1 means process every file in this process)
    :param out_folder: folder for saving all feature maps (created if missing)
    :param inputs_folder: stats folder
    :param parts: Number of parts to split the feature map while saving to pickle to avoid memory problems (Fewer parts
                  means faster loading, but can cause out of memory error)
    :param nan_th: maximum allowed fraction (0-1) of nan values, and separately of inf values, in a map.
    :param is_weights: select weights stats files (True) or gradients stats files (False); forwarded to the creator
    :param is_single_layer: forwarded to create_and_save_single_map
    :param stats_to_use: stats to be used for feature map
    :param map_files_to_create: select specific files to create (used as-is instead of scanning inputs_folder)
    :return: None
    """
    logger().log('FeatureMapDataCreator::folder_maps_creator', locals())
    if map_files_to_create is None:
        files_to_convert = _find_stats_files(inputs_folder, is_weights)
    else:
        files_to_convert = map_files_to_create
    logger().force_log_and_print('FeatureMapDataCreator::folder_maps_creator', f'Will minimize files: {files_to_convert}\n')

    slurm_id = int(os.environ.get('SLURM_ARRAY_TASK_ID', -1))
    if slurm_id == -1 or total_workers == -1:
        worker_files = files_to_convert
    else:
        chunks = _get_chunks(total_workers, files_to_convert)
        if slurm_id >= len(chunks) or len(chunks[slurm_id]) == 0:
            logger().force_log_and_print('FeatureMapDataCreator::folder_maps_creator',
                                         f'Current job has no files assigned - slurm idx: {slurm_id}\n'
                                         f'Workers files: {chunks}')
            return
        worker_files = chunks[slurm_id]

    # Make sure the destination exists; exist_ok keeps this safe when several array tasks race to create it.
    # An empty out_folder means "current directory", which needs no creation (makedirs('') would raise).
    if out_folder:
        os.makedirs(out_folder, exist_ok=True)

    # Single pre-formatted message, so this works regardless of whether log() accepts extra positional args.
    logger().log('FeatureMapDataCreator::folder_maps_creator',
                 f'Current worker: {slurm_id} works on: {worker_files}')
    for curr_file in tqdm(worker_files):
        create_and_save_single_map(file_path=curr_file, out_folder=out_folder, parts=parts, nan_th=nan_th,
                                   is_weights=is_weights, is_single_layer=is_single_layer, stats_to_use=stats_to_use)

    logger().force_log_and_print('FeatureMapDataCreator::folder_maps_creator', f'Worker: {slurm_id} Finished')


def parse_args():
    """
    Parse the command-line arguments of the feature-map minimizer.

    Exits with a usage error (via argparse) when a required folder is missing or --parts is smaller than 1.
    :return: argparse.Namespace with inputs_folder, out_folder, parts, nan_th, slurm, is_weights,
             is_single_layer and stats_to_use. is_weights / is_single_layer stay strings to be parsed by the caller.
    """
    parser = argparse.ArgumentParser(description='Minimize stats files to small feature maps')
    # Both folders are mandatory: without -i, os.listdir(None) would silently scan the current directory,
    # and without -o, os.path.join(None, ...) crashes later.
    parser.add_argument('-i', '--inputs_folder', type=str, required=True, help='stats folder')
    parser.add_argument('-o', '--out_folder', type=str, required=True, help='output folder for all feature maps')
    parser.add_argument('-p', '--parts', type=int, help='number of parts to split data when saving', default=1)
    parser.add_argument('-nan', '--nan_th', type=float, default=0,
                        help='fraction (0-1) of nan values, and separately of inf values, that is allowed in map')
    parser.add_argument('-slurm', '--slurm', type=int, help='number of jobs in slurm array', default=-1)
    parser.add_argument('-weights', '--is_weights', type=str, help='generate weights or gradients', default='1')
    parser.add_argument('-single_layer', '--is_single_layer', type=str, default='1',
                        help='generate map for single or diff layers')
    parser.add_argument('-stats', '--stats_to_use', type=str, nargs='+', default=PARTIAL_SINGLE_LAYER_WEIGHTS_STATS,
                        help='List of stats to use for creating the feature map')

    parsed_args = parser.parse_args(sys.argv[1:])
    # parts is used as a divisor when splitting the map; fail fast instead of after generating all the data.
    if parsed_args.parts < 1:
        parser.error('--parts must be at least 1')
    return parsed_args


if __name__ == '__main__':
    # Logger is named after this script's file name (text before the first '.').
    script_name = os.path.basename(__file__).split('.')[0]
    get_logger(script_name)
    cli_args = parse_args()
    # Boolean CLI flags arrive as strings ('1', 'true', 'no', ...).
    weights_flag = bool(strtobool(cli_args.is_weights))
    single_layer_flag = bool(strtobool(cli_args.is_single_layer))
    create_maps_from_folders(
        total_workers=cli_args.slurm,
        inputs_folder=cli_args.inputs_folder,
        out_folder=cli_args.out_folder,
        parts=cli_args.parts,
        nan_th=cli_args.nan_th,
        is_weights=weights_flag,
        is_single_layer=single_layer_flag,
        stats_to_use=cli_args.stats_to_use,
        map_files_to_create=None,
    )
